#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Aug  1 16:16:07 2022

Simulation of Distributional Wasserstein Distance
"""

import ot
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

def generate_uniform_sphere(d,n,R):
    """Draw n points uniformly from the sphere of radius R in R^d.

    Each point is a standard Gaussian vector rescaled to Euclidean norm R.
    The draw uses the global ``np.random`` state, so seed it with
    ``np.random.seed`` for reproducible output.

    Parameters
    ----------
    d : int
        Dimension of the ambient space.
    n : int
        Number of points to draw.
    R : float
        Radius of the sphere.

    Returns
    -------
    numpy.ndarray
        Array of shape (n, d); every row has Euclidean norm R.
    """
    # One (n, d) draw replaces n separate (1, d) draws.
    gauss = np.random.normal(0,1,size=(n,d))
    # Row-wise norms, kept as a column so the division broadcasts per row.
    return R*gauss/np.linalg.norm(gauss,axis=1,keepdims=True)


def generate_uniform_ellipsoid(d,n,sigma):
    """Draw n points on the axis-aligned ellipsoid with semi-axes sigma in R^d.

    A Gaussian vector with per-coordinate standard deviation ``sigma`` is
    divided by ``sqrt(sum(x_i**2 / sigma_i**2))``, so every returned point x
    satisfies ``sum(x_i**2 / sigma_i**2) == 1`` up to rounding.  Equivalently,
    x is a uniform point of the unit sphere with coordinate i stretched by
    sigma_i.
    NOTE: unless all sigma_i are equal this is not the uniform distribution
    with respect to the surface area of the ellipsoid.

    The draw uses the global ``np.random`` state.

    Parameters
    ----------
    d : int
        Dimension of the ambient space.
    n : int
        Number of points to draw.
    sigma : float or array_like
        Semi-axis lengths.  A scalar (or length-1 array) is broadcast to all
        d coordinates, giving a sphere of that radius; otherwise it must
        have length d.

    Returns
    -------
    numpy.ndarray
        Array of shape (n, d).

    Raises
    ------
    ValueError
        If ``sigma`` cannot be broadcast to shape (d,).
    """
    # Broadcasting up front lets scalar sigma work and makes d (not
    # len(sigma)) the single source of truth for the output dimension.
    sigma = np.broadcast_to(np.asarray(sigma,dtype=float),(d,))
    gauss = np.random.normal(0,sigma,size=(n,d))
    # Per-row normaliser that puts each sample on the ellipsoid surface.
    scale = np.sqrt(np.sum(gauss**2/sigma**2,axis=1,keepdims=True))
    return gauss/scale

    


d = 3
sigma = np.array([2,0.5,1])

# Candidate set used by the simulation: ten entries, each a (2, d) array whose
# two rows are unit vectors (the support points of a simple distribution on
# the unit sphere). The first entry is the fixed antipodal pair (+e1, -e1);
# the other nine pairs are drawn at random from the unit sphere.
prob_set = [np.array([[1,0,0],[-1,0,0]])]
prob_set.extend(generate_uniform_sphere(d,2,1) for _ in range(9))
    


# ---- Monte Carlo settings ---------------------------------------------------
ns = 1000                        # Monte Carlo replications per sample size
sample_sizes = [100,500,1000]    # sample sizes n; one figure is drawn per entry
n_seed = 20                      # NOTE(review): not referenced anywhere in the visible script

# Reference curve: density of a centred Gaussian with variance rvar,
# evaluated on the plotting grid xs.
rvar = 4/9
xs = np.linspace(-3,3,500)
limSdens = np.exp(-xs**2/(2*rvar))/np.sqrt(2*rvar*np.pi)

# Row k of dsw holds the ns statistics obtained for sample_sizes[k].
dsw = np.empty((len(sample_sizes),ns))
# Scratch buffer with one slot per candidate in prob_set.
temp = np.zeros((len(prob_set),))

for m, n in enumerate(sample_sizes, start=1):
    # Uniform weights for the two empirical distributions of size n.
    a, b = np.ones((n,)) / n, np.ones((n,)) / n
    for i in range(ns):
        # Generate samples from the ellipsoid and from the unit sphere.
        datap = generate_uniform_ellipsoid(d,n,sigma)
        dataq = generate_uniform_sphere(d,n,1)
        # Distributional Wasserstein distance between the empirical
        # distributions: for every candidate, project both samples onto its
        # two support directions, combine the two 1D transport costs with
        # weights 1/3 and 2/3, then take the maximum over the candidates.
        for j, support in enumerate(prob_set):
            cost_first = ot.emd2_1d(np.dot(datap,support[0]),np.dot(dataq,support[0]),a,b)
            cost_second = ot.emd2_1d(np.dot(datap,support[1]),np.dot(dataq,support[1]),a,b)
            temp[j] = cost_first/3 + cost_second*2/3
        dsw[m-1,i] = np.max(temp)

    # Centre at 1/3 and rescale by sqrt(n).
    # NOTE(review): 1/3 and rvar = 4/9 are presumably the population value and
    # the limiting variance for this particular setup -- confirm against the
    # accompanying derivation.
    dsw[m-1,:] = np.sqrt(n)*(dsw[m-1,:] - 1/3)

    # Kernel density estimate of the rescaled statistic (Silverman bandwidth),
    # evaluated once and reused for both the line and the shaded area.
    density = gaussian_kde(dsw[m-1,:],'silverman')
    kde_vals = density(xs)

    plt.figure(m)
    plt.plot(xs,kde_vals,color='cadetblue')
    plt.fill_between(xs,kde_vals,color='paleturquoise',alpha=0.5)
    plt.plot(xs,limSdens,color='palevioletred')
    plt.fill_between(xs,limSdens,color='pink',alpha=0.5)
    plt.xlabel("x")
    plt.ylabel("Density")
    plt.title('sample size n = '+str(n))
    
